import math
train = {  # 定义训练集
    1: {'outlook': 'sunny', 'temp': 'hot', 'hum': 'high', 'wind': 'weak', 'play': 'no'},
    2: {'outlook': 'sunny', 'temp': 'hot', 'hum': 'high', 'wind': 'strong', 'play': 'no'},
    3: {'outlook': 'overcast', 'temp': 'hot', 'hum': 'high', 'wind': 'weak', 'play': 'yes'},
    4: {'outlook': 'rain', 'temp': 'mild', 'hum': 'high', 'wind': 'weak', 'play': 'yes'},
    5: {'outlook': 'rain', 'temp': 'cool', 'hum': 'normal', 'wind': 'weak', 'play': 'yes'},
    6: {'outlook': 'rain', 'temp': 'cool', 'hum': 'normal', 'wind': 'strong', 'play': 'no'},
    7: {'outlook': 'overcast', 'temp': 'cool', 'hum': 'normal', 'wind': 'strong', 'play': 'yes'},
    8: {'outlook': 'sunny', 'temp': 'mild', 'hum': 'high', 'wind': 'weak', 'play': 'no'},
    9: {'outlook': 'sunny', 'temp': 'cool', 'hum': 'normal', 'wind': 'weak', 'play': 'yes'},
    10: {'outlook': 'rain', 'temp': 'mild', 'hum': 'normal', 'wind': 'weak', 'play': 'yes'},
    11: {'outlook': 'sunny', 'temp': 'mild', 'hum': 'normal', 'wind': 'strong', 'play': 'yes'},
    12: {'outlook': 'overcast', 'temp': 'mild', 'hum': 'high', 'wind': 'strong', 'play': 'yes'},
    13: {'outlook': 'overcast', 'temp': 'hot', 'hum': 'normal', 'wind': 'weak', 'play': 'yes'},
    14: {'outlook': 'rain', 'temp': 'mild', 'hum': 'high', 'wind': 'strong', 'play': 'no'},

}


def info(train):  # 定义信息量计算方法，传入训练集
    total, totalzheng, totalfu, info = 0, 0, 0, 0
    for key in train.keys():  # 计算训练样本中yes和no的个数
        total += 1
        if train[key]['play'] == 'yes':
            totalzheng += 1
        elif train[key]['play'] == 'no':
            totalfu += 1
    if totalfu == 0 or totalzheng == 0:
        return [total, totalzheng, totalfu, info]  # 如果全为yes或者全为no，则信息为0
    else:
        bili1 = totalzheng/(totalzheng+totalfu)
        bili2 = totalfu/(totalzheng+totalfu)
        # 计算公式为正数的比例×log2(正数的比例)
        info = bili1*math.log2(bili1)+bili2*math.log2(bili2)
        return [total, totalzheng, totalfu, round(info, 3)*-1]


def parttrain(train, targetattr, mainattr):  # 定义分离数组的方法，传入需要分离的训练集，需要分离出来的属性，和该属性所属的字段
    returndict = {}  # 定义返回字典
    for key in train.keys():
        if train[key][mainattr] == targetattr:  # 如果该条的相应的属性值等于目标属性
            returndict[key] = train[key]
    return returndict


def attrset(train, attr):  # 求该属性在该训练集下的集合
    resset = []
    for key in train.keys():
        resset.append(train[key][attr])  # 直接加入训练集中该属性下的属性值
    resset = set(resset)  # 去除重复值
    return resset


class Tree():
    def __init__(self, root):  # 初始化函数，定义节点值和结点的孩子字典
        self.root = root
        self.child = {}

    def addchild(self, attr, dict):  # 传入属性值和字典，构建孩子字典
        self.child[attr] = dict

    def show(self):  # 返回根节点
        a = {}
        a[self.root] = self.child  # 将孩子字典赋值给根节点
        return a


def maxinfo(train, attrs):  # 定义求该训练集下attrs属性列表中信息增益最大的属性的方法
    maxattr = ''
    maxnum = 0
    for attr in attrs:  # 循环所有的属性
        attrtibutes = attrset(train, attr)  # 求该属性下的属性值
        infoall = info(train)  # 求该属性的信息量
        for shuxing in attrtibutes:  # 对于每个属性值
            attrtrain = parttrain(train, shuxing, attr)  # 先分理处该属性下的训练集
            shuxinginfo = info(attrtrain)  # 求该训练集信息量
            infoall[3] -= (shuxinginfo[0]/infoall[0]) * \
                shuxinginfo[3]  # 信息增益计算公式
        if infoall[3] >= maxnum:  # 找到拥有最大信息增益的属性
            maxnum = round(infoall[3], 3)
            maxattr = attr
    return maxattr


def id3(examples, target, attributes):  # id3方法
    root = Tree(target)  # 定义根节点
    examplesnum = info(examples)  # 先求训练集下的信息量
    if examplesnum[1] != 0 and examplesnum[2] == 0:  # 如果训练集下yes不为零然后no为零，则全为yes，返回
        root.addchild(target, 'yes')
    elif examplesnum[1] == 0 and examplesnum[2] != 0:
        root.addchild(target, 'no')
    elif len(attributes) == 0:
        if examplesnum[1] >= examplesnum[2]:
            root.addchild(target, 'yes')
        else:
            root.addchild(target, 'no')
    else:
        attrs = attrset(examples, target)  # 定义属性集
        attributes.remove(target)
        for attr in attrs:
            nextexample = parttrain(examples, attr, target)
            target2 = maxinfo(nextexample, attributes)
            xunhuanattrs = []
            for i in range(0, len(attributes)):
                xunhuanattrs.append(attributes[i])
            root.addchild(attr, id3(nextexample, target2, xunhuanattrs))
    return root.show()


attrs = ['outlook', 'temp', 'hum', 'wind']
target = maxinfo(train, attrs)
a = id3(train, target, attrs)
print(a)
